from typing import Dict, Any, List, Tuple
import logging
import numpy as np
from hypersense.optimizer.base_optimizer import BaseOptimizer

import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
import hpbandster.core.nameserver as hpns
import hpbandster.optimizers.bohb as BOHB
import hpbandster.core.worker as hpworker

import time
import threading
import uuid

# Module-level logger; BOHBWorker reports per-evaluation scores (when verbose)
# and objective failures through it.
LOGGER = logging.getLogger(__name__)


class BOHBWorker(hpworker.Worker):
    """HPBandSter worker that evaluates a user objective for BOHB.

    The objective is called as ``objective_fn(config, budget)`` and is treated
    as a score to *maximize*; BOHB minimizes, so the worker reports
    ``loss = 1 - score``.
    """

    def __init__(self, objective_fn, verbose: bool = False, **kwargs):
        """
        Args:
          objective_fn: callable ``(config: dict, budget: float) -> score``.
          verbose: when True, log every evaluation at INFO level.
          **kwargs: forwarded unchanged to ``hpbandster.core.worker.Worker``
            (e.g. run_id, nameserver, nameserver_port).
        """
        super().__init__(**kwargs)
        self.objective_fn = objective_fn
        self.verbose = verbose

    def compute(self, config, budget, **kwargs):
        """Evaluate one configuration on ``budget``.

        Returns:
          ``{"loss": 1 - score, "info": {}}`` on success. If the objective
          raises (or its result cannot be converted to a float), the failure is
          logged with its traceback and reported as
          ``{"loss": 1.0, "info": {"error": <message>}}``, so a failed
          evaluation counts as score 0 instead of aborting the run.
        """
        try:
            config = self._sanitize_config(config)
            # float(): hand a plain Python float back to HPBandSter even when
            # the objective returns a NumPy (or other numeric) scalar.
            score = float(self.objective_fn(config, budget))
        except Exception as e:
            LOGGER.exception("[Worker Error] During compute: %s", e)
            return {"loss": 1.0, "info": {"error": str(e)}}

        # Logging happens outside the try so a logging problem can never turn a
        # successful evaluation into the failure result above.
        if self.verbose:
            LOGGER.info("[BOHBWorker] budget=%s, score=%.6f", budget, score)
        return {"loss": 1.0 - score, "info": {}}

    def _sanitize_config(self, config):
        """Return a copy of ``config`` with NumPy values converted to native Python types.

        Arrays become (nested) lists; any NumPy scalar (``np.generic``: ints,
        floats, bools, strings, ...) becomes the matching Python scalar via
        ``.item()``. All other values are passed through unchanged.
        """
        sanitized = {}
        for key, value in config.items():
            if isinstance(value, np.ndarray):
                sanitized[key] = value.tolist()
            elif isinstance(value, np.generic):
                sanitized[key] = value.item()
            else:
                sanitized[key] = value
        return sanitized


class BOHBOptimizer(BaseOptimizer):
    """Hyperparameter optimizer backed by HPBandSter's BOHB.

    BOHB minimizes a loss. This wrapper treats the user objective as a score to
    maximize (the worker reports ``loss = 1 - score``), and every score it
    reports back is ``1 - loss``.

    NOTE(review): ``mode`` ('max'|'min') accepted by BaseOptimizer is not
    consulted anywhere in this class; scores are always treated as "maximize".
    """

    def __init__(
        self,
        min_budget: float = 1,
        max_budget: float = 9,
        eta: int = 3,
        **kwargs,
    ):
        """
        Args:
          min_budget: smallest budget BOHB passes to the objective (must be > 0).
          max_budget: largest budget BOHB passes to the objective (>= min_budget).
          eta: successive-halving reduction factor (cast to int, must be >= 2).

        Args (inherited from BaseOptimizer via **kwargs):
          - space: Dict[str, dist], objective_fn: Callable, max_trials: int, seed: int, verbose: bool, mode: 'max'|'min'

        Raises:
          ValueError: if the budgets or eta are unusable for successive halving.
        """
        super().__init__(**kwargs)
        self.min_budget = float(min_budget)
        self.max_budget = float(max_budget)
        self.eta = int(eta)
        # Fail fast on settings successive halving cannot work with, instead of
        # failing later inside HPBandSter.
        if not 0 < self.min_budget <= self.max_budget:
            raise ValueError(
                f"Require 0 < min_budget <= max_budget, got {min_budget} and {max_budget}."
            )
        if self.eta < 2:
            raise ValueError(f"eta must be >= 2, got {eta}.")
        self.result = None        # HPBandSter result of the last optimize() call
        self.elapsed_time = None  # wall-clock seconds spent in bohb.run()

        # Logging level: ERROR if verbose=False; INFO if verbose=True
        self._log_level_quiet = logging.ERROR
        self._log_level_verbose = logging.INFO

    # ---------- internal: logging control ----------
    def _configure_hpbandster_logging(self):
        """Set HPBandSter loggers according to self.verbose.

        Quiet (verbose falsy): level ERROR and propagation disabled, so records
        from these loggers do not reach ancestor (e.g. root) handlers.
        Verbose: level INFO and propagation enabled, so INFO+ records reach the
        application's handlers.
        """
        verbose = bool(getattr(self, "verbose", False))
        level = self._log_level_verbose if verbose else self._log_level_quiet
        names = [
            "hpbandster",
            "hpbandster.core",
            "hpbandster.core.nameserver",
            "hpbandster.core.dispatcher",
            "hpbandster.core.worker",
            "hpbandster.core.master",
            "hpbandster.optimizers",
            "hpbandster.optimizers.bohb",
        ]
        for name in names:
            lg = logging.getLogger(name)
            lg.setLevel(level)
            # Propagation must stay on in verbose mode: with it off, records only
            # reach the NullHandler below and verbose output would be invisible.
            lg.propagate = verbose
            if not lg.handlers:
                lg.addHandler(logging.NullHandler())

    # ---------- internal: ConfigSpace ----------
    def _build_configspace(self) -> CS.ConfigurationSpace:
        """Translate ``self.space`` into a ConfigSpace ``ConfigurationSpace``.

        Supported entries:
          - ``name: (low, high, log)`` -> UniformFloatHyperparameter (``log`` cast to bool)
          - ``name: [choice, ...]``    -> CategoricalHyperparameter

        Raises:
          ValueError: for a define-by-run (callable) space, a space that is not
            dict-like, or an entry in an unsupported format.
        """
        if callable(self.space):
            raise ValueError("BOHB currently requires a ConfigSpace-like dict, not define-by-run.")
        if not hasattr(self.space, "items"):
            raise ValueError(
                f"BOHB requires a dict search space, got {type(self.space).__name__}."
            )

        # Seeding the ConfigurationSpace makes its random sampling reproducible
        # when a seed is configured; BOHB's model-based proposals may still draw
        # from other random sources.
        cs = CS.ConfigurationSpace(seed=getattr(self, "seed", None))
        for name, dist in self.space.items():
            if isinstance(dist, tuple) and len(dist) == 3:
                # (low, high, logscale: bool)
                low, high, log = dist
                cs.add(CSH.UniformFloatHyperparameter(name, lower=low, upper=high, log=bool(log)))
            elif isinstance(dist, list):
                cs.add(CSH.CategoricalHyperparameter(name, choices=dist))
            else:
                raise ValueError(f"Unsupported distribution format for {name}: {dist}")
        return cs

    # ---------- public: optimize ----------
    def optimize(self) -> List[Tuple[Dict[str, Any], Any, float]]:
        """
        Run BOHB and return a list of (config, score, elapsed_time) for each finished trial.
        score = 1 - loss  (BOHB minimizes loss)

        ``score`` is None for runs that reported no loss; ``elapsed_time`` is the
        run's duration from HPBandSter's time stamps, or None when a stamp is
        missing. New entries are appended to ``self.trial_history`` (created if
        missing) in finish-time order, and that list is returned.

        The name server and worker are always shut down, even if BOHB raises.

        NOTE(review): ``max_trials`` is forwarded as BOHB's ``n_iterations``
        (successive-halving iterations, each evaluating several configs), so
        more than ``max_trials`` entries may be appended — confirm intent.

        Raises:
          ValueError: if no objective function is set or the space is invalid.
        """
        # Ensure trial_history exists
        if not hasattr(self, "trial_history"):
            self.trial_history = []

        # Control HPBandSter logging output
        self._configure_hpbandster_logging()

        if not self.objective_fn:
            raise ValueError("Objective function must be provided.")

        # Build (and so validate) the search space before starting any network
        # service, so an invalid space cannot leave a name server running.
        cs = self._build_configspace()

        # port=0 asks the OS for a free port rather than relying on a fixed
        # default, so concurrent optimizations on one host do not collide.
        ns = hpns.NameServer(run_id=str(uuid.uuid4()), host="localhost", port=0)
        ns.start()
        bohb = None
        try:
            worker = BOHBWorker(
                run_id=ns.run_id,
                nameserver="localhost",
                nameserver_port=ns.port,
                objective_fn=self._wrap_objective(),
                verbose=getattr(self, "verbose", False),
            )
            # Daemon thread: if the worker is never told to shut down (e.g. BOHB
            # failed to start), it must not keep the interpreter alive.
            threading.Thread(target=worker.run, daemon=True).start()

            bohb = BOHB.BOHB(
                configspace=cs,
                run_id=ns.run_id,
                nameserver="localhost",
                nameserver_port=ns.port,
                min_budget=self.min_budget,
                max_budget=self.max_budget,
                eta=self.eta,
            )

            start = time.perf_counter()
            self.result = bohb.run(n_iterations=self.max_trials)
            # Measured before shutdown so it covers only the optimization itself.
            self.elapsed_time = time.perf_counter() - start
        finally:
            # Release workers and the name server even when BOHB raised; the
            # nested finally guarantees the name server stops even if the BOHB
            # shutdown itself fails.
            try:
                if bohb is not None:
                    bohb.shutdown(shutdown_workers=True)
            finally:
                ns.shutdown()

        self.trial_history.extend(self._collect_trials(self.result))
        return self.trial_history

    @staticmethod
    def _collect_trials(result) -> List[Tuple[Dict[str, Any], Any, Any]]:
        """Turn an HPBandSter result into (config, score, elapsed) tuples ordered by finish time."""
        id2conf = result.get_id2config_mapping()

        def finished_at(run):
            # Runs without a "finished" stamp sort first.
            return run.time_stamps.get("finished") or 0

        trials = []
        for run in sorted(result.get_all_runs(), key=finished_at):
            started = run.time_stamps.get("started")
            finished = run.time_stamps.get("finished")
            elapsed = (finished - started) if (started is not None and finished is not None) else None
            score = None if run.loss is None else (1.0 - run.loss)
            trials.append((id2conf[run.config_id]["config"], score, elapsed))
        return trials

    # ---------- internal: wrap objective ----------
    def _wrap_objective(self):
        """Wrap objective_fn to accept (config, budget).

        The objective is looked up at call time, so the wrapper raises
        ValueError if it has been unset by then.
        """
        def wrapped(config, budget):
            if self.objective_fn:
                return self.objective_fn(config, budget)
            raise ValueError("Objective function is missing.")
        return wrapped

    # ---------- public: get best config ----------
    def get_best_config(self, include_score: bool = False) -> Dict[str, Any]:
        """Return the best configuration found by the last ``optimize()`` call.

        The best configuration is HPBandSter's incumbent (lowest loss on
        ``max_budget``). If HPBandSter has no incumbent because nothing finished
        on ``max_budget``, fall back to the lowest-loss run on the largest budget
        that was evaluated. The reported score comes from that configuration's
        run on its largest evaluated budget, i.e. the budget it was chosen on.

        Args:
          include_score: when True, return a dict with "params", "score"
            (1 - loss, or None) and "elapsed_time" (seconds spent in BOHB,
            rounded to 4 decimals, or None).

        Returns:
          The configuration dict, or the summary dict described above.

        Raises:
          ValueError: if optimize() has not been run, or no run reported a loss.
        """
        if self.result is None:
            raise ValueError("No optimization has been run yet.")

        id2conf = self.result.get_id2config_mapping()
        inc_id = self.result.get_incumbent_id()  # None if nothing finished on max_budget
        if inc_id is None:
            scored = [r for r in self.result.get_all_runs() if r.loss is not None]
            if not scored:
                raise ValueError("No BOHB run finished with a loss; there is no best configuration.")
            # Prefer the largest budget reached, then the lowest loss on it.
            inc_id = min(scored, key=lambda r: (-r.budget, r.loss)).config_id

        # The configuration may have been evaluated on several budgets; report
        # the loss from the largest one instead of a possibly optimistic
        # low-fidelity result.
        inc_runs = [r for r in self.result.get_runs_by_id(inc_id) if r.loss is not None]
        best_run = max(inc_runs, key=lambda r: r.budget, default=None)

        cfg = id2conf[inc_id]["config"]
        score = None if best_run is None else (1.0 - best_run.loss)

        if include_score:
            return {
                "params": cfg,
                "score": score,
                # `is not None` so a (tiny) 0.0 duration is still reported.
                "elapsed_time": (round(self.elapsed_time, 4) if self.elapsed_time is not None else None),
            }
        return cfg
